import itertools
from typing import List, Tuple, Any

import gym.spaces as spaces
import networkx as nx
import numpy as np

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.envs.unused.octothorpe_grid_world import OctothorpeGridWorld

"""
Data types in this file:
A centralized shield/action space is a n-dimensional boolean array (usually called all_safe_actions)
The value of all_safe_actions[agent1_action, agent2_action, ..., agentn_action] denotes whether said action is safe

A factorized shield/action space is a tuple of 1-dimensional boolean arrays (usually factorized_actions)
factorized_actions[0][a] denotes whether it is safe for agent 0 to take action a, 
no matter what actions all other agents take (as long as all other agents take one of the actions 
allowed by their respective array in factorized_actions)
"""


def swap_tuple(input, index1, index2):
    """
    Build a tuple holding the items of ``input`` with the entries at ``index1`` and ``index2`` exchanged.
    ``input`` itself is never modified; any iterable is accepted because it is copied into a list first.
    """
    items = [*input]
    first, second = items[index1], items[index2]
    items[index1] = second
    items[index2] = first
    return tuple(items)


def is_factorized_action_space_safe(all_safe_actions: np.ndarray, factorized_actions: Tuple[np.ndarray]):
    """
    Check that every joint action permitted by factorized_actions is marked safe in all_safe_actions,
    i.e. that the cartesian product of the per-agent masks is a subset of the centralized shield.
    """
    # One broadcast grid per agent; a joint action is selected only when every agent's mask allows its component
    per_agent_grids = np.meshgrid(*factorized_actions, indexing='ij')
    joint_action_allowed = np.logical_and.reduce(per_agent_grids)

    # Safety values of exactly the selected joint actions (an empty selection counts as safe)
    selected_safety = all_safe_actions[joint_action_allowed]
    return np.all(selected_safety)


def maximize_agent_safe(all_safe_actions: np.ndarray, factorized_actions: Tuple[np.ndarray], agent_num):
    """
    Grow one agent's entry of factorized_actions to the largest action set that stays safe
    while every other agent is restricted to its own entry of factorized_actions.
    :param all_safe_actions: centralized shield, one axis per agent
    :param factorized_actions: per-agent boolean masks; the entry for agent_num is overwritten in place
    :param agent_num: index of the agent whose mask is recomputed
    :return: Nothing; the result is written into factorized_actions[agent_num]
    """

    # Bring this agent's axis to the front by exchanging it with axis 0
    shield_agent_first = all_safe_actions.swapaxes(0, agent_num)

    # Apply the same exchange to the per-agent masks, then drop this agent's own mask (now at the front)
    masks = list(factorized_actions)
    masks[0], masks[agent_num] = masks[agent_num], masks[0]
    other_agent_masks = masks[1:]

    # Joint mask over the remaining axes: True exactly where every other agent picks one of its allowed actions
    others_allowed = np.logical_and.reduce(np.meshgrid(*other_agent_masks, indexing='ij'))

    # Row a holds the safety of this agent's action a against each allowed combination of the other agents;
    # action a is only kept when that whole row is safe
    outcomes_per_action = shield_agent_first[:, others_allowed]
    permitted = outcomes_per_action.all(axis=1)

    assert permitted.any(), "There should be at least one safe action for every agent"

    # Overwrite the caller's mask in place so the change is visible through factorized_actions
    factorized_actions[agent_num][:] = permitted


def decentralize_actions(all_safe_actions: np.ndarray, starting_safe_action: List[int], agent_priority: List[int]):
    """
    Find some maximally-permissive factorized action set that contains starting_safe_action
    Algorithm is deterministic (and fairly greedy, so it might not necessarily find the best action)
    :param all_safe_actions: centralized shield, one axis per agent
    :param starting_safe_action: one action index per agent; the joint action must be safe in all_safe_actions
    :param agent_priority: Agents are maximized in this order. Generally, the earlier an agent appears in this
    list, the more actions will be available in its portion of the factorized action space. An agent that does not
    appear in the list is left with only its action from starting_safe_action
    :return: tuple with one 1-dimensional boolean array per axis of all_safe_actions
    :raises ValueError: if starting_safe_action does not have exactly one entry per agent
    """
    # The number of agents is defined by the shield itself, not by how many agents the caller chose to prioritize
    num_agents = all_safe_actions.ndim
    starting_safe_action = tuple(starting_safe_action)
    if len(starting_safe_action) != num_agents:
        raise ValueError(
            f"starting_safe_action has {len(starting_safe_action)} entries, "
            f"but all_safe_actions describes {num_agents} agents")
    assert all_safe_actions[starting_safe_action], "Starting action must be safe"

    # Start by building a factorized action space that only contains starting_safe_action
    factorized_actions = tuple(np.zeros(num_actions, dtype=bool) for num_actions in all_safe_actions.shape)
    for agent_mask, start_action in zip(factorized_actions, starting_safe_action):
        agent_mask[start_action] = True

    # Maximize the actions each agent can take (in priority order); each call mutates factorized_actions in place
    for agent_num in agent_priority:
        maximize_agent_safe(all_safe_actions, factorized_actions, agent_num)

    return factorized_actions


def agents_interact(all_safe_actions: np.ndarray, agent1: int, agent2: int) -> bool:
    """
    Decide whether the actions allowed to agent1 depend on the action chosen by agent2 (agent1 != agent2).

    Every 2D slice of the shield spanned by the two agents' axes is inspected. A slice is "cartesian" when its
    safe entries form a product of a set of agent1 actions and a set of agent2 actions; the agents interact
    as soon as one slice is not cartesian.
    """
    # Put the two agents of interest on the last two axes: [..., agent1_action, agent2_action]
    pairwise = np.moveaxis(all_safe_actions, (agent1, agent2), (-2, -1))

    # Per slice: which agent1 actions are safe for at least one agent2 action (trailing axis kept for broadcasting)
    agent1_ever_safe = pairwise.any(axis=-1, keepdims=True)

    # Per slice and agent2 action: does that column allow exactly the agent1 actions that are ever safe?
    column_matches_maximum = (pairwise == agent1_ever_safe).all(axis=-2)
    # Per slice and agent2 action: is the column entirely unsafe (so the agent2 action itself is unsafe)?
    column_entirely_unsafe = ~pairwise.any(axis=-2)

    # A slice is cartesian when each of its columns is either the common maximum or entirely unsafe
    independent = (column_matches_maximum | column_entirely_unsafe).all()

    return not independent


def interaction_graph(all_safe_actions: np.ndarray):
    """
    Build an undirected graph with an edge between every pair of agents whose allowed actions
    depend on each other according to agents_interact.
    Nodes are only created through the edges that get added.
    """
    num_agents = all_safe_actions.ndim

    graph = nx.Graph()
    # Each unordered pair is tested once, always as (higher index, lower index)
    graph.add_edges_from(
        (later_agent, earlier_agent)
        for later_agent in range(num_agents)
        for earlier_agent in range(later_agent)
        if agents_interact(all_safe_actions, later_agent, earlier_agent)
    )
    return graph


def is_safe_action(environment: MultiAgentSafetyEnv, current_env_state: Any, proposed_action) -> bool:
    """
    Step the environment from current_env_state with proposed_action and return the last element of the
    step result, which this module treats as the safety flag.
    """
    step_result = environment.step(current_env_state, proposed_action)
    # Only the final element is of interest; everything before it is discarded
    *_leading, is_safe = step_result
    return is_safe


def env_to_action_set(environment: MultiAgentSafetyEnv, current_env_state: Any) -> np.ndarray:
    """
    Build the centralized shield for current_env_state: a boolean array with one axis per
    discrete-actioned agent, where each entry records whether that joint action is safe.
    Agents whose action space is not spaces.Discrete contribute no axis.
    """
    discrete_sizes = tuple(
        agent_space.n
        for agent_space in environment.agent_actions_spaces()
        if isinstance(agent_space, spaces.Discrete)
    )

    shield = np.zeros(discrete_sizes, dtype=bool)
    # Try every joint action once and record the environment's verdict
    for joint_action in itertools.product(*map(range, discrete_sizes)):
        shield[joint_action] = is_safe_action(environment, current_env_state, joint_action)

    return shield


if __name__ == '__main__':
    # Hand-written shield of shape (2, 3, 4): three agents with 2, 3 and 4 actions respectively
    demo_shield = np.array(
        [
            [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
            [[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0]],
        ],
        dtype=bool,
    )
    all_zero_action = (0, 0, 0)

    # Same shield and starting action, two different priority orders
    for priority in ((0, 1, 2), (1, 0, 2)):
        print(decentralize_actions(demo_shield, all_zero_action, priority))
    demo_graph = interaction_graph(demo_shield)
    print(list(demo_graph.edges))

    # Repeat the experiment with a shield extracted from a grid-world environment state
    grid_world = OctothorpeGridWorld()
    grid_shield = env_to_action_set(grid_world, (0, (0, 2, 2, 1, 7, 0)))
    for priority in ((0, 1, 2), (2, 1, 0)):
        print(decentralize_actions(grid_shield, all_zero_action, priority))
    print(list(interaction_graph(grid_shield).edges))
